import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.sparse import mm
from torch_geometric.nn import SAGEConv, GCNConv
from torch_geometric.utils import degree, to_undirected
from torch_geometric.utils import to_scipy_sparse_matrix, to_dense_adj
from torch_geometric.nn.conv.gcn_conv import gcn_norm
import numpy as np
import scipy.sparse as sp
import os.path as osp
from typing import Optional
from .hogrl_utils import *

class Layer_AGG(nn.Module):
    """Per-relation encoder mixing a stacked SAGEConv path with gated tree branches.

    The main path applies ``num_layers`` SAGEConv layers over ``edge_index[0]``.
    In parallel, ``layers_tree`` independent SAGEConv branches are applied to
    the raw input, each over its own edge set ``edge_index[1][i]``. The
    branches are combined with softmax gates and added to the main path,
    scaled by ``weight``.
    """

    def __init__(self, in_feat, out_feat, drop_rate=0.6, weight=1, num_layers=2, layers_tree=2):
        super().__init__()
        self.drop_rate = drop_rate
        self.weight = weight
        self.num_layers = num_layers
        self.layers_tree = layers_tree

        # Main path: the first layer maps in_feat -> out_feat, later ones keep out_feat.
        self.convs = nn.ModuleList()
        fan_in = in_feat
        for _ in range(num_layers):
            self.convs.append(SAGEConv(fan_in, out_feat))
            fan_in = out_feat

        # Tree branches: every branch reads the raw input and owns a scalar gate.
        # Conv and gate are created pairwise so parameter initialisation order
        # stays the same for a fixed random seed.
        self.conv_tree = nn.ModuleList()
        self.gating_networks = nn.ModuleList()
        for _ in range(layers_tree):
            self.conv_tree.append(SAGEConv(in_feat, out_feat))
            self.gating_networks.append(nn.Linear(out_feat, 1))

        # Registered but not read by forward().
        self.bias = nn.Parameter(torch.zeros(layers_tree))

    def forward(self, x, edge_index):
        """Return ``main_path(x) + weight * gated_sum(tree_branches(x))``.

        Args:
            x: Node feature matrix passed to the convolutions.
            edge_index: Indexable pair; ``edge_index[0]`` feeds the main path
                and ``edge_index[1][i]`` feeds the i-th tree branch.
        """
        raw = x

        # Main path: ReLU + dropout after every layer except the last.
        final_depth = self.num_layers - 1
        for depth, conv in enumerate(self.convs):
            x = conv(x, edge_index[0])
            if depth == final_depth:
                continue
            x = F.dropout(F.relu(x), p=self.drop_rate, training=self.training)

        # Tree branches, each computed from the untouched input features.
        branches = []
        for idx, conv in enumerate(self.conv_tree):
            activated = F.relu(conv(raw, edge_index[1][idx]))
            branches.append(
                F.dropout(activated, p=self.drop_rate, training=self.training)
            )

        # One scalar score per node and branch, normalised across branches.
        scores = [gate(branch) for gate, branch in zip(self.gating_networks, branches)]
        alpha = F.softmax(torch.stack(scores, dim=-1), dim=-1)

        x_tree = torch.zeros_like(branches[0])
        for idx, branch in enumerate(branches):
            x_tree += branch * alpha[:, :, idx]

        return x + self.weight * x_tree
    
class multi_HOGRL_Model(nn.Module):
    """Multi-relation HOGRL classifier with contrastive and POT loss helpers.

    One ``Layer_AGG`` encoder is created per relation (stored as attributes
    ``Layers0``, ``Layers1``, ...). Their outputs are concatenated and passed
    through a linear layer followed by log-softmax. A two-layer projection
    head (``fc1``/``fc2``) backs the contrastive loss, and ``pot_loss``
    computes a BCE-with-logits loss on a CROWN-style score.
    """

    # Directory used for caching POT bounds when ``pot_loss`` is not given one.
    _DEFAULT_BOUNDS_DIR = '/data/hali/KDD/antifraud/data/amazon/bounds'

    def __init__(self, in_feat, out_feat, relation_nums=3, hidden=32,
                 drop_rate=0.6, weight=1, num_layers=2, layers_tree=2,
                 temperature=0.5, dataset="YelpChi"):
        super(multi_HOGRL_Model, self).__init__()
        self.relation_nums = relation_nums
        self.drop_rate = drop_rate
        self.weight = weight
        self.layers_tree = layers_tree
        self.tau = temperature
        self.dataset = dataset

        # Original HOGRL layers: one aggregator per relation.
        for i in range(relation_nums):
            setattr(self, 'Layers' + str(i),
                    Layer_AGG(in_feat, hidden, self.drop_rate,
                              self.weight, num_layers, self.layers_tree))
        self.linear = nn.Linear(hidden * relation_nums, out_feat)

        # POT / contrastive loss components: a two-layer projection head.
        num_proj_hidden = hidden
        self.fc1 = nn.Linear(hidden * relation_nums, num_proj_hidden)
        self.fc2 = nn.Linear(num_proj_hidden, hidden * relation_nums)
        self.pot_loss_func = nn.BCEWithLogitsLoss()

    def forward(self, x, edge_index):
        """Run every relation encoder and classify the concatenated output.

        Args:
            x: Node feature matrix.
            edge_index: Indexable container; ``edge_index[i]`` is handed to the
                i-th relation's ``Layer_AGG``.

        Returns:
            Tuple ``(log_probs, x_temp)`` where ``x_temp`` is the concatenation
            (along dim 1) of all relation outputs and ``log_probs`` is the
            log-softmax of ``self.linear(x_temp)``.
        """
        layer_outputs = [
            getattr(self, 'Layers' + str(i))(x, edge_index[i])
            for i in range(self.relation_nums)
        ]
        x_temp = torch.cat(layer_outputs, dim=1)
        x = F.log_softmax(self.linear(x_temp), dim=1)
        return x, x_temp

    def projection(self, z: torch.Tensor) -> torch.Tensor:
        """Apply the projection head: ``fc2(elu(fc1(z)))``."""
        z = F.elu(self.fc1(z))
        return self.fc2(z)

    def sim(self, z1: torch.Tensor, z2: torch.Tensor):
        """Return the pairwise cosine-similarity matrix between rows of z1 and z2."""
        z1 = F.normalize(z1)
        z2 = F.normalize(z2)
        return torch.mm(z1, z2.t())

    def semi_loss(self, z1: torch.Tensor, z2: torch.Tensor):
        """Return the per-row contrastive loss of ``z1`` against ``z2``.

        The positive for row ``i`` is ``z2[i]``; every other row of ``z1`` and
        ``z2`` contributes to the denominator. Similarities are scaled by
        ``1 / self.tau`` before exponentiation.
        """
        refl_sim = torch.exp(self.sim(z1, z1) / self.tau)
        between_sim = torch.exp(self.sim(z1, z2) / self.tau)
        return -torch.log(
            between_sim.diag() /
            (refl_sim.sum(1) + between_sim.sum(1) - refl_sim.diag())
        )

    def loss(self, z1: torch.Tensor, z2: torch.Tensor, mean: bool = True):
        """Return the symmetric contrastive loss between two embedding views.

        Both inputs are first passed through ``projection``. The result is
        averaged over rows when ``mean`` is true and summed otherwise.
        """
        h1 = self.projection(z1)
        h2 = self.projection(z2)
        l1 = self.semi_loss(h1, h2)
        l2 = self.semi_loss(h2, h1)
        ret = (l1 + l2) * 0.5
        ret = ret.mean() if mean else ret.sum()
        return ret

    def _compute_pot_bounds(self, edge_index, A, local_changes):
        """Build dense element-wise upper/lower bounds for the normalised adjacency."""
        deg = degree(to_undirected(edge_index)[1]).cpu().numpy()
        A_tilde = A + sp.eye(A.shape[0])

        # Upper bound: shrink each self-loop-augmented degree by the number of
        # edges that may be deleted (capped by ``local_changes``), then take
        # the outer product of the inverse square roots, masked to the
        # existing entries of A + I.
        degs_tilde = deg + 1
        max_delete = np.maximum(degs_tilde.astype("int") - 2, 0)
        max_delete = np.minimum(max_delete, np.round(local_changes).astype("int"))
        sqrt_degs_tilde_max_delete = 1 / np.sqrt(degs_tilde - max_delete)
        A_upper = sqrt_degs_tilde_max_delete * sqrt_degs_tilde_max_delete[:, None]
        A_upper = np.where(A_tilde.toarray() > 0, A_upper, np.zeros_like(A_upper))
        A_upper = np.float32(A_upper)

        # Lower bound: keep only the diagonal of the GCN-normalised adjacency.
        new_edge_index, An = gcn_norm(edge_index, num_nodes=A.shape[0])
        An = to_dense_adj(new_edge_index, edge_attr=An)[0].cpu().numpy()
        A_lower = np.zeros_like(An)
        A_lower[np.diag_indices_from(A_lower)] = np.diag(An)
        A_lower = np.float32(A_lower)
        return A_upper, A_lower

    def pot_loss(self, z1: torch.Tensor, z2: torch.Tensor, x, edge_index, edge_index_1: torch.Tensor, local_changes, drop_rate: float, node_list = None, A_upper=None, A_lower=None, bounds_dir: Optional[str] = None):
        """Compute the POT loss on the sub-graph induced by ``node_list``.

        Args:
            z1: Only its ``device`` is used, to place the tensors built here.
            z2: Embeddings used to build the contrastive weights ``Wcl``.
            x: Node feature matrix.
            edge_index: Edge index of the original graph.
            edge_index_1: Edge index of the perturbed graph.
            local_changes: Cap on the number of deleted edges per node used
                when deriving the upper bound.
            drop_rate: Only used in the file name of the cached bounds.
            node_list: Indices of the nodes to score. Required.
            A_upper: Precomputed dense upper bound. When ``None`` both bounds
                are recomputed (any ``A_lower`` passed in is ignored) and
                saved to disk.
            A_lower: Precomputed dense lower bound.
            bounds_dir: Directory where freshly computed bounds are saved.
                Defaults to ``_DEFAULT_BOUNDS_DIR``.

        Returns:
            The accumulated POT loss divided by ``self.relation_nums``.

        Raises:
            ValueError: If ``node_list`` is ``None``.
        """
        # Function-scope import: the module itself only imports ``os.path``.
        import os

        if node_list is None:
            raise ValueError("pot_loss requires node_list (the node indices to score)")

        # Sparse adjacency of the original graph; its size is the node count.
        A = to_scipy_sparse_matrix(edge_index).tocsr()

        # Compute the bounds when no precomputed ones are supplied.
        if A_upper is None:
            A_upper, A_lower = self._compute_pot_bounds(edge_index, A, local_changes)

            # Persist the bounds so later runs can reuse them.
            if bounds_dir is None:
                bounds_dir = self._DEFAULT_BOUNDS_DIR
            os.makedirs(bounds_dir, exist_ok=True)
            upper_lower_file = osp.join(bounds_dir, f"{self.dataset}_{drop_rate}_upper_lower.pkl")
            torch.save((A_upper, A_lower), upper_lower_file)

        N = len(node_list)
        node_index = torch.as_tensor(node_list, device=z1.device)

        # Restrict the bounds to the selected nodes and move them to the device.
        A_upper_tensor = torch.tensor(A_upper[node_list][:, node_list], device=z1.device).to_sparse()
        A_lower_tensor = torch.tensor(A_lower[node_list][:, node_list], device=z1.device).to_sparse()

        # Interval centre and radius are shared by every bound computation below.
        A_mid = (A_upper_tensor + A_lower_tensor) / 2
        A_rad = (A_upper_tensor - A_lower_tensor) / 2

        # Contrastive weights depend only on z2, not on the relation.
        z2_norm = F.normalize(z2)
        z2_sum = z2_norm.sum(axis=0)
        Wcl = z2_norm * (N / (N - 1)) - z2_sum / (N - 1)

        # Normalised adjacency of the perturbed graph, restricted to node_list.
        edge_index_ptb_sl, An_ptb = gcn_norm(edge_index_1, num_nodes=A.shape[0])
        An_ptb = torch.sparse_coo_tensor(
            edge_index_ptb_sl, An_ptb, size=(A.shape[0], A.shape[0])
        ).index_select(0, node_index).index_select(1, node_index)

        alpha = 0  # ReLU activation

        # All relations share the same upper/lower bounds.
        total_pot_loss = 0
        for i in range(self.relation_nums):
            layer = getattr(self, f'Layers{i}')
            # SAGEConv exposes lin_l / lin_r rather than a single lin.
            first_lin = layer.convs[0].lin_l
            second_lin = layer.convs[1].lin_l
            W1, b1 = first_lin.weight.t(), first_lin.bias
            W2, b2 = second_lin.weight.t(), second_lin.bias
            gcn_weights = [W1, b1, W2, b2]

            XW = first_lin(x)[node_list]
            H = F.relu(layer.convs[0](x, edge_index))
            HW = second_lin(H)[node_list]

            # Pre-activation bounds via interval arithmetic on the adjacency.
            z1_U = mm(A_mid, XW) + mm(A_rad, torch.abs(XW)) + b1
            z1_L = mm(A_mid, XW) - mm(A_rad, torch.abs(XW)) + b1
            z2_U = mm(A_mid, HW) + mm(A_rad, torch.abs(HW)) + b2
            z2_L = mm(A_mid, HW) - mm(A_rad, torch.abs(HW)) + b2

            # CROWN weights.
            W_tilde_1, b_tilde_1, W_tilde_2, b_tilde_2 = get_crown_weights(z1_L, z1_U, z2_L, z2_U, alpha, gcn_weights, Wcl)

            # POT score on the perturbed graph.
            XW_tilde = (x[node_list, None, :] @ W_tilde_1[:, :, None]).view(-1, 1)
            H_tilde = mm(An_ptb, XW_tilde) + b_tilde_1.view(-1, 1)
            pot_score = mm(An_ptb, H_tilde) + b_tilde_2.view(-1, 1)
            pot_score = pot_score.squeeze()

            # POT loss of the current relation: push every score towards label 1.
            target = torch.ones(pot_score.shape, device=z1.device)
            total_pot_loss += self.pot_loss_func(pot_score, target)
            # NOTE(review): only the first relation contributes because of this
            # break, yet the result is still divided by relation_nums below.
            # Kept as in the original; confirm whether this is intentional.
            break

        return total_pot_loss / self.relation_nums

    def get_embedding(self, x, edge_index):
        """Return the concatenated per-relation node embeddings from ``forward``."""
        # forward() returns (log_probs, embeddings); keep only the embeddings.
        _, h = self.forward(x, edge_index)
        return h

class Graphsage(nn.Module):
    """Two-layer GraphSAGE baseline with a 2-way linear classification head."""

    def __init__(self, in_feat, out_feat):
        super().__init__()
        self.conv1 = SAGEConv(in_feat, out_feat)
        self.conv2 = SAGEConv(out_feat, out_feat)
        self.linear = nn.Linear(out_feat, 2)

    def forward(self, x, edge_index):
        """Return log-softmax class scores for every row of ``x``.

        Each convolution is followed by ReLU and dropout (p=0.6, active only
        in training mode) before the linear head is applied.
        """
        for conv in (self.conv1, self.conv2):
            hidden = F.relu(conv(x, edge_index))
            x = F.dropout(hidden, p=0.6, training=self.training)
        logits = self.linear(x)
        return F.log_softmax(logits, dim=1)